{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from pathlib import Path\n",
    "import mediapy\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Walk up from the current directory until a directory named 'gpudrive' is found,\n",
    "# then make it the working directory.\n",
    "working_dir = Path.cwd()\n",
    "while working_dir.name != 'gpudrive':\n",
    "    parent_dir = working_dir.parent\n",
    "    # Give up at the home directory (the original stop condition) or at the\n",
    "    # filesystem root, whose parent is itself. Without the root check the loop\n",
    "    # never terminates when the notebook is launched from the home directory\n",
    "    # itself or from anywhere outside it.\n",
    "    if parent_dir == Path.home() or parent_dir == working_dir:\n",
    "        raise FileNotFoundError(\"Base directory 'gpudrive' not found\")\n",
    "    working_dir = parent_dir\n",
    "os.chdir(working_dir)\n",
    "\n",
    "from gpudrive.env.dataset import SceneDataLoader\n",
    "from gpudrive.env.config import EnvConfig\n",
    "from gpudrive.env.env_torch import GPUDriveTorchEnv\n",
    "from gpudrive.visualize.utils import img_from_fig\n",
    "\n",
    "import logging\n",
    "# Configure the root logger at INFO level\n",
    "logging.basicConfig(level=logging.INFO)\n",
    "\n",
    "# Load IPython's autoreload extension; mode 2 reloads imported modules before each cell runs\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scene loader over the files stored under data/processed/examples\n",
    "train_loader = SceneDataLoader(\n",
    "    root=\"data/processed/examples\",\n",
    "    dataset_size=1000,\n",
    "    batch_size=4,  # number of worlds\n",
    "    shuffle=False,\n",
    "    sample_with_replacement=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from examples.experimental.eval_utils import load_policy, rollout\n",
    "\n",
    "# Load the policy on the CPU from examples/experimental/models.\n",
    "# NOTE(review): model_name is an empty string - presumably a placeholder for the\n",
    "# name of the model to load; confirm before running.\n",
    "policy = load_policy(\n",
    "    device='cpu',\n",
    "    model_name='',\n",
    "    path_to_cpt='examples/experimental/models',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sanity check that the loaded weights are not random:\n",
    "# uncomment the loop below to print each parameter's mean and standard deviation.\n",
    "# for name, param in policy.state_dict().items():\n",
    "#     print(f\"{name} - Mean: {param.mean():.4f}, Std: {param.std():.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the loaded policy object as the cell output\n",
    "policy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### GPUDriveTorchEnv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the environment on CPU from a default EnvConfig and the scene loader\n",
    "env = GPUDriveTorchEnv(\n",
    "    config=EnvConfig(),\n",
    "    data_loader=train_loader,\n",
    "    max_cont_agents=64,\n",
    "    device=\"cpu\",\n",
    ")\n",
    "\n",
    "# Print env.data_batch (presumably the scenes loaded for the current batch - confirm)\n",
    "print(env.data_batch)\n",
    "\n",
    "# Reset the environment and keep only the entries selected by cont_agent_mask\n",
    "obs = env.reset()[env.cont_agent_mask]\n",
    "\n",
    "print(f'observation_space: {env.observation_space}')\n",
    "print(f'obs shape: {obs.shape}')\n",
    "print(f'obs dtype: {obs.dtype} \\n')\n",
    "\n",
    "print(f'action_space: {env.action_space}')\n",
    "\n",
    "# Histogram of all initial observation values, labeled so the figure stands alone;\n",
    "# the trailing semicolon suppresses the repr of the last call\n",
    "plt.hist(obs.flatten())\n",
    "plt.title('Initial observation values')\n",
    "plt.xlabel('Observation value')\n",
    "plt.ylabel('Count');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the simulator state for environment index 0 at time step 0 as a visual\n",
    "# check of the state the rollout will start from\n",
    "env.vis.figsize = (5, 5)\n",
    "sim_states = env.vis.plot_simulator_state(\n",
    "    time_steps=[0],\n",
    "    env_indices=[0],\n",
    "    zoom_radius=100,\n",
    ")\n",
    "\n",
    "# The first returned item becomes the cell output\n",
    "sim_states[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the policy in the environment with deterministic=True and simulator-state\n",
    "# rendering enabled\n",
    "rollout_outputs = rollout(\n",
    "    env=env,\n",
    "    policy=policy,\n",
    "    device='cpu',\n",
    "    deterministic=True,\n",
    "    render_sim_state=True,\n",
    "    zoom_radius=100,\n",
    ")\n",
    "\n",
    "# Unpack the twelve returned values, in the order rollout() returns them\n",
    "(\n",
    "    goal_achieved_count, frac_goal_achieved,\n",
    "    collided_count, frac_collided,\n",
    "    off_road_count, frac_off_road,\n",
    "    not_goal_nor_crash_count, frac_not_goal_nor_crash_per_scene,\n",
    "    controlled_agents_per_scene,\n",
    "    sim_state_frames,\n",
    "    agent_positions,\n",
    "    episode_lengths,\n",
    ") = rollout_outputs\n",
    "\n",
    "# Report the outcome fractions\n",
    "print('\\n Results: \\n')\n",
    "print(f'Goal achieved: {frac_goal_achieved}')\n",
    "print(f'Collided: {frac_collided}')\n",
    "print(f'Off road: {frac_off_road}')\n",
    "print(f'Not goal nor crashed: {frac_not_goal_nor_crash_per_scene}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the frames captured during the rollout as GIF animations at 15 fps\n",
    "mediapy.show_videos(\n",
    "    sim_state_frames,\n",
    "    codec='gif',\n",
    "    fps=15,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gpudrive",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
